"""Fully connected network."""
import numpy as np
from datetime import datetime 

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hidden units per layer when the width multiplier is 1; FCN hidden layers
# have int(BASE_WIDTH * width) units.
BASE_WIDTH = 256


class FCN(nn.Module):
    """Fully connected ReLU network.

    Layout of ``self.classifier``: ``depth`` (Linear, ReLU) pairs followed by a
    final Linear head, i.e. indices ``0 .. 2 * depth``. Every hidden Linear has
    ``int(BASE_WIDTH * width)`` output units. Inputs are flattened with
    ``nn.Flatten()`` (keeps dim 0) before the first layer.

    Args:
        input_size: number of features of one flattened input sample.
        n_classes: number of outputs of the final Linear layer.
        depth: number of hidden (Linear, ReLU) pairs; 0 means a single Linear.
        width: multiplier applied to ``BASE_WIDTH`` for every hidden layer.
    """

    def __init__(self, input_size, n_classes, depth=2, width=1):
        super().__init__()
        self.input_size = input_size
        self.n_classes = n_classes
        self.depth = depth
        self.width = width
        self.flatten = nn.Flatten()
        # Plain list (not a ModuleList) on purpose: the layers are registered
        # through ``self.classifier``, which keeps the state_dict keys stable.
        self.hidden_layers = self._get_parameterized_layers(depth, width)
        # The head consumes the last hidden layer's output, or the flattened
        # input directly when there are no hidden layers (depth == 0).
        head_in = self.hidden_layers[-1].out_features if self.hidden_layers else input_size
        self.output_layer = nn.Linear(in_features=head_in, out_features=n_classes)
        self.all_layers = self._get_all_layers(self.hidden_layers, self.output_layer)
        self.classifier = nn.Sequential(*self.all_layers)

    @property
    def last_layer_name(self):
        """state_dict prefix of the final layer in ``self.classifier`` (e.g. ``'classifier.4'``)."""
        return 'classifier.{}'.format(len(self.classifier) - 1)

    def get_last_layer_weights(self):
        """Return ``(weight, bias)`` of the head actually used by ``forward``.

        The tensors are detached, matching what ``state_dict()`` would return
        for ``last_layer_name``.
        """
        head = self.classifier[-1]
        return head.weight.detach(), head.bias.detach()

    def forward(self, x):
        """Flatten ``x`` and run it through ``self.classifier``."""
        return self.classifier(self.flatten(x))

    def _get_parameterized_layers(self, depth, width):
        """Build ``depth`` chained Linear layers of ``int(BASE_WIDTH * width)`` units.

        The first layer takes ``self.input_size`` features; each following
        layer takes the previous layer's output.
        """
        n_out = int(BASE_WIDTH * width)
        layers = []
        n_in = self.input_size
        for _ in range(depth):
            layers.append(nn.Linear(in_features=n_in, out_features=n_out))
            n_in = n_out
        return layers

    def _get_all_layers(self, hidden_layers, output_layer):
        """Interleave a fresh ``nn.ReLU`` after every hidden layer, then append the head."""
        all_layers = []
        for layer in hidden_layers:
            all_layers += [layer, nn.ReLU()]
        all_layers.append(output_layer)
        return all_layers

    def combine_model(self, model_1, model_2, breakpoint=-1):
        """Glue different parts of two models together.

        ``self`` becomes model_2's hidden layers followed by model_1's final
        layer. The layer modules are shared with model_1/model_2, not copied.
        Neither input model is modified.

        Args:
            model_1: model whose final layer (head) is used.
            model_2: model whose hidden (Linear, ReLU) layers are used.
            breakpoint: unused; kept for backward compatibility.
        """
        # Copy so that model_2.all_layers is not mutated in place.
        combined = list(model_2.all_layers)
        combined[-1] = model_1.all_layers[-1]
        # Keep every layer view of self consistent with the new classifier.
        self.hidden_layers = list(model_2.hidden_layers)
        self.output_layer = combined[-1]
        self.all_layers = combined
        self.classifier = nn.Sequential(*combined)

    def get_intermediate(self, x):
        """Return the output of every layer of ``self.classifier`` for input ``x``.

        Keys are ``'layer{i}'`` for classifier index ``i``. The second-to-last
        entry (the last ReLU when depth >= 1) is keyed ``'avg_pool'`` instead.
        Presumably this mirrors the feature key of other (CNN) models; TODO confirm.
        """
        layers = list(self.classifier)
        names = ['layer{}'.format(i) for i in range(len(layers))]
        if len(names) >= 2:  # a depth-0 network has a single layer
            names[-2] = 'avg_pool'
        lookup = {}
        out = self.flatten(x)
        for name, layer in zip(names, layers):
            out = layer(out)
            lookup[name] = out
        return lookup
